{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "size = 32\n",
    "embedding_dim = 16\n",
    "num_positive_samples = 3\n",
    "num_negative_samples = 6\n",
    "batch_size = 8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# generate data\n",
    "torch.manual_seed(1234)\n",
    "# positive\n",
    "subset = torch.randint(size, size=(batch_size, num_positive_samples + 1))\n",
    "src, pos = subset.split((1, num_positive_samples), dim=-1)\n",
    "assert src.size() == (batch_size,  1)\n",
    "assert pos.size() == (batch_size, num_positive_samples)\n",
    "# negative\n",
    "negs = torch.randint(size, size=(batch_size, num_negative_samples))\n",
    "assert negs.size() == (batch_size, num_negative_samples)\n",
    "\n",
    "# embedder\n",
    "embedder = torch.nn.Embedding(size, embedding_dim)\n",
    "embedder.reset_parameters()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mrr tensor(0.3159)\n",
      "loss tensor(20.9239)\n"
     ]
    }
   ],
   "source": [
    "with torch.no_grad():\n",
    "    # encode\n",
    "    embedding = embedder(src)\n",
    "    assert embedding.size() == (batch_size,  1, embedding_dim)\n",
    "    pos_embedding = embedder(pos)\n",
    "    assert pos_embedding.size() == (batch_size,  num_positive_samples, embedding_dim)\n",
    "    negs_embedding = embedder(negs)\n",
    "    assert negs_embedding.size() == (batch_size,  num_negative_samples, embedding_dim)\n",
    "    \n",
    "    logits = torch.sum(embedding * pos_embedding, dim=2)\n",
    "    assert logits.shape == (batch_size, num_positive_samples)\n",
    "    negs_logits = torch.sum(embedding * negs_embedding, dim=2)\n",
    "    assert negs_logits.shape == (batch_size, num_negative_samples)\n",
    "    \n",
    "    # compute mrr\n",
    "    mrr_all = torch.cat((negs_logits, logits), dim=-1)\n",
    "    mrr_size = mrr_all.shape[-1]\n",
    "    _, indices_of_ranks = mrr_all.topk(mrr_size)\n",
    "    _, ranks = (-indices_of_ranks).topk(mrr_size)\n",
    "    mrr = ranks[:,-1].float().reciprocal().mean()\n",
    "    print('mrr', mrr)\n",
    "    \n",
    "    # compute loss\n",
    "    pos_loss = F.logsigmoid(logits).sum(dim=-1)\n",
    "    negs_loss = F.logsigmoid(-negs_logits).sum(dim=-1)\n",
    "    loss = -(pos_loss + negs_loss).mean()\n",
    "    print('loss', loss)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
